public class Solution1277 {
    /**
     * Counts the square submatrices whose cells are all ones.
     *
     * <p>Uses the classic bottom-right-corner DP. For each cell equal to 1, the side of
     * the largest all-ones square ending at that cell is
     * {@code 1 + min(up, left, upLeft)}. A cell with largest side {@code s} is the
     * bottom-right corner of exactly {@code s} all-ones squares (sides 1..s). Summing
     * those sides gives the answer. Runs in O(rows * cols) time with O(cols) extra memory.
     *
     * @param matrix a rectangular grid whose cells are expected to be 0 or 1; only cells
     *               equal to 1 are treated as ones
     * @return the number of all-ones square submatrices; 0 for an empty matrix
     */
    public int countSquares(int[][] matrix) {
        if (matrix.length == 0 || matrix[0].length == 0) {
            return 0;
        }
        int cols = matrix[0].length;
        // side[j + 1] holds the largest square side ending at column j of the current row.
        // Before it is overwritten, it holds the value for the previous row.
        // side[0] is a permanent zero sentinel for the column left of the grid.
        int[] side = new int[cols + 1];
        int total = 0;
        for (int[] row : matrix) {
            int upLeft = 0; // previous row's value one column to the left
            for (int j = 0; j < cols; j++) {
                int up = side[j + 1];
                if (row[j] == 1) {
                    // side[j] was already updated this row, so it is the "left" neighbour.
                    side[j + 1] = 1 + Math.min(upLeft, Math.min(up, side[j]));
                } else {
                    side[j + 1] = 0;
                }
                upLeft = up;
                total += side[j + 1];
            }
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(new Solution1277().countSquares(new int[][]{{0,1,1,1},{1,1,1,1},{0,1,1,1}}));
    }
}
